#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

import os
import numpy as np
import cv2
import json
import sys
import struct
from collections import OrderedDict
import warnings

import torch
from torch.utils import data

# sys.path.insert(0, './scripts_internal/')
# import write_desc
from pytorch_jacinto_ai import xnn

###########################################
# config settings
def get_config():
    """Build and return the default dataset configuration node.

    Returns an ``xnn.utils.ConfigNode`` holding the loader defaults
    (input folders, offsets and which targets to load).
    """
    cfg = xnn.utils.ConfigNode()
    cfg.image_folders = ('leftImg8bit',)
    cfg.input_offsets = None
    cfg.load_segmentation = True
    cfg.use_semseg_for_depth = False
    cfg.load_segmentation_flow_correction = False
    return cfg


class TiadBaseSegmentationLoader():
    """Metadata and label helpers for the 5-class TIAD segmentation task.

    Data is derived from CityScapes, downloadable from
    https://www.cityscapes-dataset.com/downloads/
    Many thanks to @fvisin for the loader repo:
    https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
    """
    # RGB palette per train id (the trailing entry is beyond num_classes_).
    colors = [  # [  0,   0,   0],
        [152, 251, 152], [0, 130, 180], [220, 20, 60], [3, 3, 251], [190, 153, 153], [0, 0, 0]]  #[0, 130, 180](sky)  #[152, 251, 152](vegetation)  #[102, 102, 156](vehicle)

    label_colours = dict(zip(range(5), colors))

    void_classes = [-1, 255]
    valid_classes = [0, 1, 2, 3, 4]
    class_names = ['road', 'sky', 'pedestrian', 'vehicle', 'background']

    ignore_index = 255
    class_map = dict(zip(valid_classes, range(5)))
    num_classes_ = 5

    class_weights_ = np.array([0.30594229,  1., 25.07696964,  2.59353056,  0.38336123], dtype=float)
    # class_weights_ = np.ones(num_classes_)

    @classmethod
    def decode_segmap(cls, temp):
        """Colorize a 2-D train-id image into an RGB float image in [0, 1].

        Pixels whose value is not a known train id keep their raw value
        (divided by 255) in every channel.
        """
        channels = [temp.copy(), temp.copy(), temp.copy()]
        for class_id in range(cls.num_classes_):
            selected = (temp == class_id)
            for channel, value in zip(channels, cls.label_colours[class_id]):
                channel[selected] = value
        #
        return np.stack(channels, axis=-1) / 255.0


    @classmethod
    def encode_segmap(cls, mask):
        """Remap raw label values to train ids in place; voids -> ignore_index."""
        for raw_value, train_id in cls.class_map.items():
            mask[mask == raw_value] = train_id
        # everything declared void becomes the ignore index
        for raw_value in cls.void_classes:
            mask[mask == raw_value] = cls.ignore_index
        return mask


    @classmethod
    def class_weights(cls):
        """Per-class loss weights for this task."""
        return cls.class_weights_

class TiadBaseSemanticMotionLoader():
    """Metadata and label helpers for the 6-class semantic+motion task.

    Data is derived from CityScapes, downloadable from
    https://www.cityscapes-dataset.com/downloads/
    Many thanks to @fvisin for the loader repo:
    https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
    """
    # RGB palette per train id (the trailing entry is beyond num_classes_).
    colors = [  # [  0,   0,   0],
        [152, 251, 152], [0, 130, 180], [220, 20, 60], [3, 3, 251], [190, 153, 153], [220, 20, 60], [0, 0, 0]]  #[0, 130, 180](sky)  #[152, 251, 152](vegetation)  #[102, 102, 156](vehicle)

    label_colours = dict(zip(range(6), colors))

    void_classes = [-1, 255]
    valid_classes = [0, 1, 2, 3, 4, 5]
    class_names = ['road', 'sky', 'pedestrian', 'vehicle', 'background','vehicle_moving']

    ignore_index = 255
    class_map = dict(zip(valid_classes, range(6)))
    num_classes_ = 6

    class_weights_ = np.array([0.30594229,  1., 25.07696964,  5.18353056,  0.38336123,  5.18353056], dtype=float)

    @classmethod
    def decode_segmap(cls, temp):
        """Colorize a 2-D train-id image into an RGB float image in [0, 1].

        Pixels whose value is not a known train id keep their raw value
        (divided by 255) in every channel.
        """
        channels = [temp.copy(), temp.copy(), temp.copy()]
        for class_id in range(cls.num_classes_):
            selected = (temp == class_id)
            for channel, value in zip(channels, cls.label_colours[class_id]):
                channel[selected] = value
        #
        return np.stack(channels, axis=-1) / 255.0


    @classmethod
    def encode_segmap(cls, mask):
        """Remap raw label values to train ids in place; voids -> ignore_index."""
        for raw_value, train_id in cls.class_map.items():
            mask[mask == raw_value] = train_id
        # everything declared void becomes the ignore index
        for raw_value in cls.void_classes:
            mask[mask == raw_value] = cls.ignore_index
        return mask


    @classmethod
    def class_weights(cls):
        """Per-class loss weights for this task."""
        return cls.class_weights_



class TiadBaseMotionLoader():
    """Metadata and label helpers for the 2-class static/moving motion task.

    Data is derived from CityScapes, downloadable from
    https://www.cityscapes-dataset.com/downloads/
    Many thanks to @fvisin for the loader repo:
    https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
    """
    colors = [
        [0, 0, 0], [128, 64, 128], [0, 0, 0]]

    label_colours = dict(zip(range(2), colors))

    void_classes = [2]
    valid_classes = [0, 255] #255 #1
    class_names = ['static', 'moving']
    ignore_index = 255
    class_map = dict(zip(valid_classes, range(2)))
    num_classes_ = 2
    class_weights_ = np.array([0.05, 0.95], dtype=float)  # [ 0.51222399, 20.95158417]  #[0.02, 0.98]  #[0.05, 0.95]

    @classmethod
    def decode_segmap(cls, temp):
        """Colorize a 2-D train-id image into an RGB float image in [0, 1].

        Pixels whose value is not a known train id keep their raw value
        (divided by 255) in every channel.
        """
        channels = [temp.copy(), temp.copy(), temp.copy()]
        for class_id in range(cls.num_classes_):
            selected = (temp == class_id)
            for channel, value in zip(channels, cls.label_colours[class_id]):
                channel[selected] = value
        #
        return np.stack(channels, axis=-1) / 255.0


    @classmethod
    def encode_segmap(cls, mask):
        """Remap raw labels in place: 0 -> static, 255 -> moving, then the
        void class 2 -> ignore_index. Order matters: 255 is remapped to 1
        before 2 is remapped to 255."""
        for raw_value, train_id in cls.class_map.items():
            mask[mask == raw_value] = train_id
        for raw_value in cls.void_classes:
            mask[mask == raw_value] = cls.ignore_index
        return mask

    @classmethod
    def class_weights(cls):
        """Per-class loss weights for this task."""
        return cls.class_weights_


class TiadDataLoader(data.Dataset):
    """TIAD dataset loader using a CityScapes-style directory layout.

    Each sample is ``(images, targets)``: ``images`` holds one frame per entry
    of ``image_folders`` (converted BGR->RGB, optionally offset-subtracted) and
    ``targets`` holds the enabled ground truths in the fixed order: flow,
    depth, segmentation, semantic-motion, motion, interest-point descriptor.
    With ``additional_info=True`` items are
    ``(images, targets, images_path, targets_path)`` instead.
    """

    def __init__(self, root, split="train", gt="gtFine", transforms=None, image_folders=('leftImg8bit',),
                 search_images=False, load_segmentation=True, load_segmentation_flow_correction=False, load_semantic_motion=False, load_depth=False, load_motion = False, load_flow=False, load_interest_pt = False,
                 inference=False, additional_info = False, start_offsets=None, end_offsets=None, input_offsets=None,
                 akaze_format='precomputed_bin', max_depth = 20.0, depth_scale = 1, use_semseg_for_depth=False, train_depth_log= False):
        """``root`` may be a str, or a list/tuple with one entry per image
        folder (its last entry is treated as the current-frame root)."""
        super().__init__()
        if split not in ['train', 'val', 'test']:
            warnings.warn(f'unknown split specified: {split}')
        self.root = root if not isinstance(root, (tuple, list)) else root[-1]
        self.gt = gt
        self.split = split
        self.transforms = transforms
        self.image_folders = image_folders
        self.search_images = search_images
        self.files = {}
        self.additional_info = additional_info
        self.start_offsets = start_offsets
        self.end_offsets = end_offsets

        self.load_segmentation = load_segmentation
        self.load_segmentation_flow_correction = load_segmentation_flow_correction
        self.load_semantic_motion = load_semantic_motion
        self.load_depth = load_depth
        self.load_motion = load_motion
        self.load_flow = load_flow
        self.load_interest_pt = load_interest_pt
        self.num_interest_pt_channels = 65

        self.inference = inference
        self.input_offsets = input_offsets

        self.image_suffix = '.png'
        self.segmentation_suffix = self.gt + '_labelTrainIds.png'
        self.semantic_motion_suffix = self.gt + '_labelTrainIds_semantic_motion.png'
        self.depth_suffix = '.png'
        self.motion_suffix = self.gt+'_labelTrainIds_motion.png'
        self.set_index_for_zero_mean()
        self.akaze_format = akaze_format
        self.max_depth = max_depth
        self.depth_scale = depth_scale
        self.train_depth_log = train_depth_log
        self.use_semseg_for_depth = use_semseg_for_depth

        self.image_base = os.path.join(self.root, image_folders[-1], self.split) if not isinstance(root, (list, tuple)) else self.root
        self.segmentation_base = os.path.join(self.root, gt, self.split)
        self.semantic_motion_base = os.path.join(self.root, gt, self.split)
        self.depth_base = os.path.join(self.root, 'depth', self.split)
        self.ss_label_base = os.path.join(self.root, 'seg_labels', self.split)
        self.cameracalib_base = os.path.join(self.root, 'camera', self.split)
        self.motion_base = os.path.join(self.root, gt, self.split)

        # the primary file list comes from whichever modality is enabled,
        # checked in this priority order
        if self.search_images:
            self.files = xnn.utils.recursive_glob(rootdir=self.image_base, suffix=self.image_suffix)
        elif self.load_segmentation:
            self.files = xnn.utils.recursive_glob(rootdir=self.segmentation_base, suffix=self.segmentation_suffix)
        elif self.load_semantic_motion:
            self.files = xnn.utils.recursive_glob(rootdir=self.semantic_motion_base, suffix=self.semantic_motion_suffix)
        elif self.load_motion:
            self.files = xnn.utils.recursive_glob(rootdir=self.motion_base, suffix=self.motion_suffix)
        elif self.load_depth:
            self.files = xnn.utils.recursive_glob(rootdir=self.depth_base, suffix=self.depth_suffix)
        #
        self.files = sorted(self.files)

        if not self.files:
            raise Exception("> No files for split=[%s] found in %s" % (split, self.segmentation_base))
        #

        self.image_files = [None] * len(image_folders)
        for image_idx, image_folder in enumerate(image_folders):
            image_base = os.path.join(self.root, image_folder, self.split) if not isinstance (root, (list, tuple)) else root[image_idx]
            self.image_files[image_idx] = sorted(xnn.utils.recursive_glob(rootdir=image_base, suffix='.png'))
            # select only the required files; end_offset == 0 means "to the end"
            if (self.start_offsets and self.end_offsets):
                start_offset = self.start_offsets[image_idx]
                end_offset = self.end_offsets[image_idx]
                self.image_files[image_idx] = self.image_files[image_idx][start_offset:] \
                    if (end_offset == 0) else self.image_files[image_idx][start_offset:end_offset]
            #
            assert len(self.image_files[image_idx]) == len(self.image_files[0]), 'all folders should have same number of files'
        #
        if (self.start_offsets and self.end_offsets):
            start_offset = self.start_offsets[-1]
            end_offset = self.end_offsets[-1]
            self.files = self.files[start_offset:] if (end_offset == 0) else self.files[start_offset:end_offset]
        #

        # Cache the first (transformed) sample; __getitem__ returns it whenever
        # an image on disk turns out to be corrupt/unreadable.
        # Bugfix: the non-additional_info branch previously stored
        # 'first_imges' while the additional_info branch stored 'first_images',
        # so the fallback raised AttributeError in additional_info mode.
        if self.additional_info:
            self.first_images, self.first_targets, self.first_images_path, self.first_targets_path = self.__getitem__(0)
        else:
            self.first_images, self.first_targets = self.__getitem__(0)
        #
        # keep the historical misspelled name as a backward-compatible alias
        self.first_imges = self.first_images

    def set_index_for_zero_mean(self):
        """Precompute AKAZE descriptor channel indices: channels with a
        positive-only range vs. channels symmetric around zero (used by the
        conv_desc_to_* helpers). The +2 skips the loc_x/loc_y fields present
        in the raw binary record."""
        self.pos_chan_idx = np.asarray([3, 4, 7, 8, 11, 12, 15, 16, 19, 20, 23, 24, 27, 28, 31, 32, 35, 36, 39, 40, 43, 44, 47, 48, 51,
                        52, 55, 56, 59, 60, 63, 64])
        self.symmetric_chan_idx = np.asarray([1, 2, 5, 6, 9, 10, 13, 14, 17, 18, 21, 22, 25, 26, 29, 30, 33, 34, 37, 38, 41, 42, 45, 46,
                              49, 50, 53, 54, 57, 58, 61, 62])
        two_pos_for_loc = 2
        self.pos_chan_idx_in_data = self.pos_chan_idx + two_pos_for_loc
        self.symmetric_chan_idx_in_data = self.symmetric_chan_idx + two_pos_for_loc

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Load the input image(s) and enabled targets for ``index``."""
        # derive the companion image path from the primary (target) file list
        if self.search_images:
            image_path = self.files[index].rstrip()
            self.check_file_exists(image_path)
            segmentation_path = image_path.replace(self.image_base, self.segmentation_base).replace(self.image_suffix, '_' + self.segmentation_suffix)
        elif self.load_segmentation or self.load_segmentation_flow_correction:
            segmentation_path = self.files[index].rstrip()
            self.check_file_exists(segmentation_path)
            image_path = segmentation_path.replace(self.segmentation_base, self.image_base).replace('_' + self.segmentation_suffix, self.image_suffix)
        elif self.load_semantic_motion:
            semantic_motion_path = self.files[index].rstrip()
            self.check_file_exists(semantic_motion_path)
            image_path = semantic_motion_path.replace(self.semantic_motion_base, self.image_base).replace('_' + self.semantic_motion_suffix, self.image_suffix)
        elif self.load_motion:
            motion_path = self.files[index].rstrip()
            self.check_file_exists(motion_path)
            image_path = motion_path.replace(self.motion_base, self.image_base).replace('_' + self.motion_suffix, self.image_suffix)
        elif self.load_depth:
            depth_path = self.files[index].rstrip()
            self.check_file_exists(depth_path)
            image_path = depth_path.replace(self.depth_base, self.image_base).replace('depth_', 'image_')
        elif self.load_interest_pt:
            image_path = self.files[index].rstrip()
            self.check_file_exists(image_path)
        #

        images = []
        images_path = []
        for image_idx, image_folder in enumerate(self.image_folders):
            this_image_path =  self.image_files[image_idx][index].rstrip()
            if image_idx == (len(self.image_folders)-1):
                # the last folder is the current frame; it must agree with the
                # path derived from the target file
                assert this_image_path == image_path, 'image file name error'
            #
            self.check_file_exists(this_image_path)
            img_bgr = cv2.imread(this_image_path)  # also reused by the interest-pt path below
            img = img_bgr
            if img is None:
                # read failed: return the pre-stored (transformed) first
                # sample, matching the tuple shape this method normally yields
                if self.additional_info:
                    return self.first_images, self.first_targets, self.first_images_path, self.first_targets_path
                else:
                    return self.first_images, self.first_targets
            else:
                img = img[:,:,::-1]  # BGR -> RGB
            #
            if self.input_offsets is not None:
                img = img - self.input_offsets[image_idx]
            #
            images.append(img)
            images_path.append(this_image_path)
        #

        targets = []
        targets_path = []
        if self.load_flow and (not self.inference):
            # no flow ground truth available: a zero map acts as a placeholder
            flow_zero = np.zeros((images[0].shape[0],images[0].shape[1],2), dtype=np.float32)
            targets.append(flow_zero)

        if self.load_depth and (not self.inference):
            depth_path = image_path.replace(self.image_base, self.depth_base).replace('image_','depth_')
            ss_label_path = ''
            if self.use_semseg_for_depth:
                # segmentation labels let depth_loader force sky pixels to max depth
                ss_label_path = image_path.replace(self.image_base, self.ss_label_base).replace('image_','depth_')
                self.check_file_exists(ss_label_path)

            self.check_file_exists(depth_path)
            depth = self.depth_loader(depth_path, max_depth_afr_scale= self.max_depth, ss_label_path=ss_label_path, depth_scale = self.depth_scale,
              train_depth_log = self.train_depth_log)
            targets.append(depth)
        #

        if (self.load_segmentation or self.load_segmentation_flow_correction) and (not self.inference):
            lbl = cv2.imread(segmentation_path, 0).astype(np.uint8)
            lbl = TiadBaseSegmentationLoader.encode_segmap(np.array(lbl, dtype=np.uint8))
            if self.load_segmentation:
                targets.append(lbl)
                targets_path.append(segmentation_path)
            if self.load_segmentation_flow_correction:
                # grey out (128) the pixels of class 1 in the first input image
                images[0][:,:,:-1][cv2.resize(lbl, dsize=(images[0].shape[1],images[0].shape[0]), interpolation=cv2.INTER_NEAREST) == 1] = 128
        #

        if self.load_semantic_motion and (not self.inference):
            lbl = cv2.imread(semantic_motion_path, 0).astype(np.uint8)
            lbl = TiadBaseSemanticMotionLoader.encode_segmap(np.array(lbl, dtype=np.uint8))
            targets.append(lbl)
            targets_path.append(semantic_motion_path)
        #

        if self.load_motion and (not self.inference):
            motion_path = image_path.replace(self.image_base, self.motion_base).replace(self.image_suffix, '_' + self.motion_suffix)
            self.check_file_exists(motion_path)
            motion = cv2.imread(motion_path, 0).astype(np.uint8)
            motion = TiadBaseMotionLoader.encode_segmap(np.array(motion, dtype=np.uint8))
            targets.append(motion)
        #

        if self.load_interest_pt and (not self.inference):
            # uses img_bgr left over from the last image-folder iteration above
            interest_pt_descriptor = self.compute_interest_pt_descriptor(img_bgr, image_path)
            targets.append(interest_pt_descriptor)

        if (self.transforms is not None):
            images, targets = self.transforms(images, targets)
        #

        if self.additional_info:
            return images, targets, images_path, targets_path
        else:
            return images, targets
    #


    def decode_segmap(self, lbl):
        """Colorize a label image according to the currently enabled task."""
        if self.load_segmentation:
            return TiadBaseSegmentationLoader.decode_segmap(lbl)
        elif self.load_semantic_motion:
            # bugfix: semantic-motion labels have 6 classes and need their own
            # palette (the 5-class segmentation decoder was used previously)
            return TiadBaseSemanticMotionLoader.decode_segmap(lbl)
        else:
            return TiadBaseMotionLoader.decode_segmap(lbl)
    #


    def check_file_exists(self, file_name):
        """Raise if ``file_name`` is not an existing regular file."""
        if not os.path.exists(file_name) or not os.path.isfile(file_name):
            raise Exception("{} is not a file, can not open with imread.".format(file_name))
    #

    # stretch given array linearly into the range [min, max]
    def stretch_to_range(self, ip_array = [], min=1.0, max=255.0):
        # NOTE(review): divides by (max - min) of the input; a constant array
        # would divide by zero — callers only guard against the all-zero case
        ip_range = ip_array.max() - ip_array.min()
        new_range = max - min
        offset = min - (ip_array.min() * new_range / ip_range)
        op_array = (ip_array * new_range / ip_range) + offset
        return op_array

    def depth_loader(self, depth_path, max_depth_afr_scale=20, ss_label_path = '', depth_scale=1.0, ignore_depth_val=-1,
      train_depth_log = False):
        """Load a depth PNG (uint8 or uint16), zero out the ignore value,
        scale and clip to ``max_depth_afr_scale``. Optionally remap to
        log10(depth), and/or force sky pixels (from a segmentation label at
        ``ss_label_path``) to the maximum depth."""
        depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)

        if depth.dtype == 'uint16':
            # the default ignore value is the dtype maximum
            ignore_depth_val = ignore_depth_val if ignore_depth_val != -1 else np.iinfo(np.uint16).max
            depth[depth==ignore_depth_val] = 0

            # train on log10(depth), re-stretched to 0..255
            if train_depth_log:
                if not np.all(depth==0):
                    depth = self.stretch_to_range(ip_array = depth, min=1.0, max=255.0)
                    depth = np.log10(depth)
                    depth = self.stretch_to_range(ip_array = depth, min=0.0, max=255.0)
                else:
                    depth = np.float64(depth)
            else:
                # depth/256.0 undoes the 16-bit PNG storage scale
                depth = (depth/256.0) * depth_scale

            depth[depth>max_depth_afr_scale]=max_depth_afr_scale
        elif depth.dtype == 'uint8':
            ignore_depth_val = ignore_depth_val if ignore_depth_val != -1 else np.iinfo(np.uint8).max
            if depth_scale != 1:
                depth = np.uint8(depth * depth_scale)
            depth[depth==ignore_depth_val] = 0
            depth[depth>max_depth_afr_scale]=max_depth_afr_scale
        else:
            # TODO(review): exit() terminates the process / DataLoader worker;
            # raising an exception would be friendlier — kept for compatibility
            exit("float data type for depth GT?")

        # use the seg label to assign max depth to far-away regions like sky
        ss_label = cv2.imread(ss_label_path, cv2.IMREAD_UNCHANGED) if ss_label_path != '' else None
        if ss_label_path != '':
            sky_label = 1
            depth[ss_label == sky_label] = max_depth_afr_scale
        return depth
    #
    def num_classes(self):
        """Per-target channel/class counts, in target order."""
        nc = []
        if self.load_flow:
            nc.append(2)
        if self.load_depth:
            nc.append(1)
        if self.load_segmentation:
            nc.append(TiadBaseSegmentationLoader.num_classes_)
        if self.load_semantic_motion:
            nc.append(TiadBaseSemanticMotionLoader.num_classes_)
        if self.load_motion:
            nc.append(TiadBaseMotionLoader.num_classes_)
        if self.load_interest_pt:
            nc.append(self.num_interest_pt_channels)
        #
        return nc
    #


    def class_weights(self):
        """Per-target class weights (None where not applicable), in target order."""
        cw = []
        if self.load_depth:
            cw.append(None)
        if self.load_segmentation:
            cw.append(TiadBaseSegmentationLoader.class_weights())
        if self.load_semantic_motion:
            cw.append(TiadBaseSemanticMotionLoader.class_weights())
        if self.load_motion:
            cw.append(TiadBaseMotionLoader.class_weights())
        if self.load_interest_pt:
            cw.append(None)
        #
        return cw
    #

    def create_palette(self):
        """Color palettes for the enabled label-image targets."""
        palette = []
        if self.load_segmentation:
            palette.append(TiadBaseSegmentationLoader.colors)
        if self.load_semantic_motion:
            palette.append(TiadBaseSemanticMotionLoader.colors)
        if self.load_motion:
            palette.append(TiadBaseMotionLoader.colors)
        return palette

    def conv_desc_to_uniform_positive_range(self, descriptor=[], loc_x=0, loc_y=0, data=[], scale_score=127/0.001, scale_des=127):
        """Write one AKAZE record into ``descriptor`` at (loc_y, loc_x), with
        all channels scaled into the positive range 0..255 (score in channel 0,
        descriptor values in channels 1..64)."""
        data_ary = np.asarray(data)
        descriptor[loc_y, loc_x, 0] = np.clip((data_ary[2]*scale_score),0.0,255.0)

        # actual range is 0-1.0 but very few samples exceed 0.5, so treat the range as 0-0.5
        scale_des = 510.0
        descriptor[loc_y, loc_x, self.pos_chan_idx] = np.clip((data_ary[self.pos_chan_idx_in_data]*scale_des), 0.0, 255.0)

        # actual range is -1.0 to 1.0 but very few samples beyond +-0.5, so treat the range as -0.5 to 0.5
        scale_des = 254.0
        descriptor[loc_y, loc_x, self.symmetric_chan_idx] = np.clip(data_ary[self.symmetric_chan_idx_in_data]*scale_des+128, 0.0, 255.0)

    def conv_desc_to_uniform_range(self, descriptor=[], loc_x=0, loc_y=0, data=[], scale_score=127/0.001, scale_des=127):
        """Write one AKAZE record into ``descriptor`` at (loc_y, loc_x), with
        all channels scaled into the signed range -128..127 (zero-mean score)."""
        data_ary = np.asarray(data)
        descriptor[loc_y, loc_x, 0] = np.clip((data_ary[2]*scale_score)-128.0,-128.0,127.0)

        scale_des = 127.0*4.0
        descriptor[loc_y, loc_x, self.pos_chan_idx] = np.clip((data_ary[self.pos_chan_idx_in_data]*scale_des)-128.0,-128.0,127.0)
        scale_des = 127.0*2.0
        descriptor[loc_y, loc_x, self.symmetric_chan_idx] = np.clip(data_ary[self.symmetric_chan_idx_in_data]*scale_des, -128.0,127.0)

    def read_akaze_score_desc_bin(self, image_name = [], img_shape = [], scale_score=127/0.001, scale_des=127, akaze_ds_fac = 1,
        akaze_params=[]):
        """Read a precomputed AKAZE .bin file into a dense HxWx65 descriptor
        map. Records are float32 tuples: (loc_x, loc_y, score, desc[64])."""
        float_size = 4
        #loc_x,loc_y,score,desc[64]
        num_interest_pt_channels = 1+64
        num_el = 2+num_interest_pt_channels
        skip_fac = 2

        filename =  image_name.replace('.png', '.bin')
        num_bytes = os.path.getsize(filename)

        tot_kp = num_bytes//(num_el*float_size)
        descriptor = np.zeros((img_shape[0]//akaze_ds_fac, img_shape[1]//akaze_ds_fac, num_interest_pt_channels), dtype=np.float32)
        # bugfix: use a context manager so the file handle is closed (it leaked before)
        with open(filename, "rb") as bin_file:
            for _ in range(0, tot_kp):
                data = struct.unpack('f'*num_el, bin_file.read(float_size*num_el))
                loc_x = int(data[0])
                loc_y = int(data[1])

                # keep only keypoints lying on the skip_fac-subsampled grid
                if not (loc_x % skip_fac) and not (loc_y % skip_fac):
                    if akaze_params.learn_scaled_values_interest_pt:
                        if akaze_params.make_score_zero_mean:
                            self.conv_desc_to_uniform_range(descriptor=descriptor, loc_x=loc_x, loc_y=loc_y, data=data, scale_score=scale_score, scale_des=scale_des)
                        elif akaze_params.uniform_positive_range:
                            self.conv_desc_to_uniform_positive_range(descriptor=descriptor, loc_x=loc_x, loc_y=loc_y, data=data, scale_score=scale_score, scale_des=scale_des)
                        else:
                            descriptor[loc_y, loc_x, 0] = np.clip(data[2]*scale_score,0.0,255.0)
                            descriptor[loc_y, loc_x, 1:] = np.asarray(data[3:])*scale_des
                    else:
                        descriptor[loc_y, loc_x, 0] = data[2]
                        descriptor[loc_y, loc_x, 1:] = np.asarray(data[3:])

        return descriptor

    def set_akaze_params(self):
        """Default AKAZE processing options. Note: values are stored as
        attributes on an OrderedDict instance, not as dict items."""
        akaze_params = OrderedDict()

        akaze_params.learn_scaled_values_interest_pt = True
        # akaze_format choices: 'precomputed_bin', 'precomputed_npy', 'compute', 'none'
        #akaze_params.akaze_format = 'none'
        akaze_params.make_score_zero_mean = False
        akaze_params.uniform_positive_range = True
        return akaze_params

    def compute_interest_pt_descriptor(self, img_bgr, image_path):
        """Produce the dense interest-point descriptor target for one image,
        either from precomputed .npz/.bin files or by running AKAZE."""
        akaze_params = self.set_akaze_params()

        akaze_save_np_format = False
        # save in the format the localization pipeline can consume directly
        akaze_save_localization_format = False

        # to speed up, AKAZE can be computed at reduced resolution
        AKAZE_ON_LOW_RES = False
        akaze_ds_fac = 2 if AKAZE_ON_LOW_RES else 1
        scale_score = 127.0/0.001
        scale_des = 127.0

        if self.akaze_format == 'precomputed_npy':
            filename =  image_path.replace('.png', '.npz')
            npz_ary = np.load(filename)
            descriptor = npz_ary['arr_0']
        elif self.akaze_format == 'precomputed_bin':
            # bin format generated by the KazeGTTraining utility
            descriptor = self.read_akaze_score_desc_bin(image_name = image_path, img_shape = img_bgr.shape,
                scale_score=scale_score, scale_des=scale_des, akaze_ds_fac = akaze_ds_fac,
                akaze_params = akaze_params)
            xnn.utils.comp_hist_tensor3d(x=descriptor, en = False, dir = 'gt', name='ch', log = True)
        elif self.akaze_format == 'compute':
            if AKAZE_ON_LOW_RES:
                descriptor = np.zeros((img_bgr.shape[0]//akaze_ds_fac, img_bgr.shape[1]//akaze_ds_fac, self.num_interest_pt_channels), dtype=np.float32)
                img_bgr = cv2.resize(img_bgr, (descriptor.shape[1], descriptor.shape[0]), interpolation=cv2.INTER_AREA)
            else:
                descriptor = np.zeros((img_bgr.shape[0], img_bgr.shape[1], self.num_interest_pt_channels), dtype=np.float32)
            akaze_th = 0.0 #0.001
            akaze = cv2.AKAZE_create(descriptor_type=cv2.AKAZE_DESCRIPTOR_KAZE, threshold=akaze_th)
            kpts, descs = akaze.detectAndCompute(img_bgr, None)

            if akaze_save_localization_format:
                filename = image_path.replace('.png', '.txt')
                # NOTE(review): akaze_params.scale_to_write_kp_loc_to_orig_res is
                # never set by set_akaze_params(); this (disabled) branch would
                # raise AttributeError if enabled — confirm before turning on
                scale_to_write_kp_loc_to_orig_res = akaze_ds_fac if akaze_params.scale_to_write_kp_loc_to_orig_res == -1 else akaze_params.scale_to_write_kp_loc_to_orig_res
                # write_desc.write_immediate_score_desc_as_text(kpts=kpts, descs=descs, scale_to_write_kp_loc_to_orig_res=scale_to_write_kp_loc_to_orig_res,
                #                                               txt_file_name=filename, fract_loc = True)
            if len(kpts) > 0:
                for kpt, desc in zip(kpts, descs):
                    pt = np.round(kpt.pt)
                    if akaze_params.learn_scaled_values_interest_pt:
                        descriptor[int(pt[1]),int(pt[0]),0] = np.clip(kpt.response*scale_score,0.0,255.0)
                        descriptor[int(pt[1]),int(pt[0]),1:] = desc*scale_des
                    else:
                        descriptor[int(pt[1]), int(pt[0]), 0] = np.clip(kpt.response, 0.0, 0.002)
                        descriptor[int(pt[1]), int(pt[0]), 1:] = desc

            if akaze_save_np_format:
                filename =  image_path.replace('.png', '.npz')
                np.savez_compressed(filename, descriptor)
        else: # 'none'
            descriptor = np.zeros((img_bgr.shape[0], img_bgr.shape[1], self.num_interest_pt_channels), dtype=np.float32)

        return descriptor


##########################################
def tiad_segmentation(dataset_config, root, split=None, transforms=None):
    """Build (train, val) TiadDataLoader instances for semantic segmentation.

    Note: the incoming ``split`` argument is ignored; both 'train' and 'val'
    loaders are always constructed, using transforms[0] / transforms[1].
    """
    gt = "gtFine"
    loaders = {'train': None, 'val': None}
    for transform_idx, split_name in enumerate(('train', 'val')):
        loaders[split_name] = TiadDataLoader(root, split_name, gt, transforms=transforms[transform_idx])
    #
    return loaders['train'], loaders['val']



def tiad_depth(dataset_config, root, split=None, transforms=None):
    """Build (train, val) TiadDataLoader instances for depth estimation.

    Note: the incoming ``split`` argument is ignored; both 'train' and 'val'
    loaders are always constructed, using transforms[0] / transforms[1].
    """
    gt = "gtFine"
    depth_kwargs = dict(load_segmentation=False, load_depth=True,
                        max_depth=dataset_config.max_depth_bfr_scaling,
                        depth_scale=dataset_config.depth_scale,
                        use_semseg_for_depth=dataset_config.use_semseg_for_depth,
                        train_depth_log=dataset_config.train_depth_log)
    train_split = TiadDataLoader(root, 'train', gt, transforms=transforms[0], **depth_kwargs)
    val_split = TiadDataLoader(root, 'val', gt, transforms=transforms[1], **depth_kwargs)
    return train_split, val_split


def tiad_interest_pt_descriptor(dataset_config, root, split=None, transforms=None):
    """Build train/val TIAD interest-point/descriptor datasets.

    :param dataset_config: dataset configuration (unused here; kept for a uniform
        factory signature across the tiad_* constructors)
    :param root: dataset root directory
    :param split: optional iterable of split names to build; defaults to
        ('train', 'val'). Previously this argument was silently ignored.
    :param transforms: two-element sequence; transforms[0] for train, transforms[1] for val
    :return: (train_split, val_split) — an entry is None when its split was not requested
    """
    gt = "gtFine"
    train_split = val_split = None
    # Honor a caller-supplied split list instead of unconditionally overwriting it.
    split = ['train', 'val'] if split is None else split
    for split_name in split:
        if split_name == 'train':
            train_split = TiadDataLoader(root, split_name, gt, transforms=transforms[0], search_images=True,
                                         load_segmentation=False, load_interest_pt=True)
        elif split_name == 'val':
            val_split = TiadDataLoader(root, split_name, gt, transforms=transforms[1], search_images=True,
                                       load_segmentation=False, load_interest_pt=True)
    #
    return train_split, val_split



def tiad_depth_segmentation(dataset_config, root, split=None, transforms=None):
    """Build train/val TIAD joint depth + segmentation datasets.

    :param dataset_config: dataset configuration (unused here; kept for a uniform
        factory signature across the tiad_* constructors)
    :param root: dataset root directory
    :param split: optional iterable of split names to build; defaults to
        ('train', 'val'). Previously this argument was silently ignored.
    :param transforms: two-element sequence; transforms[0] for train, transforms[1] for val
    :return: (train_split, val_split) — an entry is None when its split was not requested
    """
    gt = "gtFine"
    train_split = val_split = None
    # Honor a caller-supplied split list instead of unconditionally overwriting it.
    split = ['train', 'val'] if split is None else split
    for split_name in split:
        if split_name == 'train':
            train_split = TiadDataLoader(root, split_name, gt, transforms=transforms[0], load_depth=True)
        elif split_name == 'val':
            val_split = TiadDataLoader(root, split_name, gt, transforms=transforms[1], load_depth=True)
    #
    return train_split, val_split


##########################################
def tiad_motion_multi_input(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'),
                           input_offsets=None):
    """Build train/val TIAD motion-segmentation datasets with multi-frame input.

    :param dataset_config: merged with defaults via get_config(); supplies image_folders
    :param root: dataset root directory
    :param split: optional iterable of split names to build; defaults to
        ('train', 'val'). Previously this argument was silently ignored.
    :param transforms: two-element sequence; transforms[0] for train, transforms[1] for val
    :param image_folders: NOTE(review): this parameter is shadowed by
        dataset_config.image_folders below (kept as-is for backward compatibility)
    :param input_offsets: frame offsets forwarded to the loader
    :return: (train_split, val_split) — an entry is None when its split was not requested
    """
    dataset_config = get_config().merge_from(dataset_config)
    gt = "gtFine"
    train_split = val_split = None
    # Honor a caller-supplied split list instead of unconditionally overwriting it.
    split = ['train', 'val'] if split is None else split
    for split_name in split:
        if split_name == 'train':
            train_split = TiadDataLoader(root, split_name, gt, transforms=transforms[0], load_segmentation=False, load_motion=True,
                                               image_folders=dataset_config.image_folders, input_offsets=input_offsets)
        elif split_name == 'val':
            val_split = TiadDataLoader(root, split_name, gt, transforms=transforms[1], load_segmentation=False, load_motion=True,
                                               image_folders=dataset_config.image_folders, input_offsets=input_offsets)
    #
    return train_split, val_split


##########################################
def tiad_semantic_motion_single_task_multi_input(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'),
                           input_offsets=None):
    """Build train/val TIAD single-task semantic+motion datasets with multi-frame input.

    :param dataset_config: merged with defaults via get_config(); supplies image_folders
    :param root: dataset root directory
    :param split: optional iterable of split names to build; defaults to
        ('train', 'val'). Previously this argument was silently ignored.
    :param transforms: two-element sequence; transforms[0] for train, transforms[1] for val
    :param image_folders: NOTE(review): this parameter is shadowed by
        dataset_config.image_folders below (kept as-is for backward compatibility)
    :param input_offsets: frame offsets forwarded to the loader
    :return: (train_split, val_split) — an entry is None when its split was not requested
    """
    dataset_config = get_config().merge_from(dataset_config)
    gt = "gtFine"
    train_split = val_split = None
    # Honor a caller-supplied split list instead of unconditionally overwriting it.
    split = ['train', 'val'] if split is None else split
    for split_name in split:
        if split_name == 'train':
            train_split = TiadDataLoader(root, split_name, gt, transforms=transforms[0], load_segmentation=False, load_semantic_motion=True,
                                               image_folders=dataset_config.image_folders, input_offsets=input_offsets)
        elif split_name == 'val':
            val_split = TiadDataLoader(root, split_name, gt, transforms=transforms[1], load_segmentation=False, load_semantic_motion=True,
                                               image_folders=dataset_config.image_folders, input_offsets=input_offsets)
    #
    return train_split, val_split

##########################################
# --joint motion semantic training
def tiad_motion_semantic_multi_input(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'), input_offsets=None):
    """Build train/val TIAD joint motion + semantic datasets with multi-frame input.

    :param dataset_config: merged with defaults via get_config(); supplies image_folders
    :param root: dataset root directory
    :param split: optional iterable of split names to build; defaults to
        ('train', 'val'). Previously this argument was silently ignored.
    :param transforms: two-element sequence; transforms[0] for train, transforms[1] for val
    :param image_folders: NOTE(review): this parameter is shadowed by
        dataset_config.image_folders below (kept as-is for backward compatibility)
    :param input_offsets: frame offsets forwarded to the loader
    :return: (train_split, val_split) — an entry is None when its split was not requested
    """
    dataset_config = get_config().merge_from(dataset_config)
    gt = "gtFine"
    train_split = val_split = None
    # Honor a caller-supplied split list instead of unconditionally overwriting it.
    split = ['train', 'val'] if split is None else split
    for split_name in split:
        if split_name == 'train':
            train_split = TiadDataLoader(root, split_name, gt, transforms=transforms[0], load_motion=True,
                                               image_folders=dataset_config.image_folders, input_offsets=input_offsets)
        elif split_name == 'val':
            val_split = TiadDataLoader(root, split_name, gt, transforms=transforms[1], load_motion=True,
                                               image_folders=dataset_config.image_folders, input_offsets=input_offsets)
    #
    return train_split, val_split


############################################
def tiad_depth_semantic_motion_multi_input(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'),
                                                     input_offsets=None):
    """Build train/val TIAD joint depth + semantic + motion datasets with multi-frame input.

    :param dataset_config: merged with defaults via get_config(); supplies image_folders,
        max_depth_bfr_scaling, train_depth_log, use_semseg_for_depth,
        load_segmentation_flow_correction and depth_scale
    :param root: dataset root directory
    :param split: optional iterable of split names to build; defaults to
        ('train', 'val'). Previously this argument was silently ignored.
    :param transforms: two-element sequence; transforms[0] for train, transforms[1] for val
    :param image_folders: NOTE(review): this parameter is shadowed by
        dataset_config.image_folders below (kept as-is for backward compatibility)
    :param input_offsets: frame offsets forwarded to the loader
    :return: (train_split, val_split) — an entry is None when its split was not requested
    """
    dataset_config = get_config().merge_from(dataset_config)
    gt = "gtFine"
    train_split = val_split = None
    # Honor a caller-supplied split list instead of unconditionally overwriting it.
    split = ['train', 'val'] if split is None else split
    for split_name in split:
        if split_name == 'train':
            train_split = TiadDataLoader(root, split_name, gt, transforms=transforms[0], load_depth=True,
                                               load_motion=True, image_folders=dataset_config.image_folders, input_offsets=input_offsets,
                                               max_depth=dataset_config.max_depth_bfr_scaling,
                                               train_depth_log=dataset_config.train_depth_log, use_semseg_for_depth=dataset_config.use_semseg_for_depth,
                                               load_segmentation_flow_correction=dataset_config.load_segmentation_flow_correction,
                                               depth_scale=dataset_config.depth_scale)
        elif split_name == 'val':
            val_split = TiadDataLoader(root, split_name, gt, transforms=transforms[1], load_depth=True,
                                             load_motion=True, image_folders=dataset_config.image_folders, input_offsets=input_offsets,
                                             max_depth=dataset_config.max_depth_bfr_scaling,
                                             train_depth_log=dataset_config.train_depth_log, use_semseg_for_depth=dataset_config.use_semseg_for_depth,
                                             load_segmentation_flow_correction=dataset_config.load_segmentation_flow_correction,
                                             depth_scale=dataset_config.depth_scale)
    #
    return train_split, val_split

############################################
def tiad_depth_semantic_motion_descriptor_multi_input(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'),
                                                     input_offsets=None):
    """Build train/val TIAD depth + semantic + motion + descriptor datasets (multi-frame input).

    :param dataset_config: merged with defaults via get_config(); supplies image_folders
    :param root: dataset root directory
    :param split: optional iterable of split names to build; defaults to
        ('train', 'val'). Previously this argument was silently ignored.
    :param transforms: two-element sequence; transforms[0] for train, transforms[1] for val
    :param image_folders: NOTE(review): this parameter is shadowed by
        dataset_config.image_folders below (kept as-is for backward compatibility)
    :param input_offsets: frame offsets forwarded to the loader
    :return: (train_split, val_split) — an entry is None when its split was not requested
    """
    dataset_config = get_config().merge_from(dataset_config)
    gt = "gtFine"
    train_split = val_split = None
    # Honor a caller-supplied split list instead of unconditionally overwriting it.
    split = ['train', 'val'] if split is None else split
    for split_name in split:
        if split_name == 'train':
            train_split = TiadDataLoader(root, split_name, gt, transforms=transforms[0], load_depth=True,
                                               load_motion=True, load_interest_pt=True,
                                               image_folders=dataset_config.image_folders, input_offsets=input_offsets)
        elif split_name == 'val':
            val_split = TiadDataLoader(root, split_name, gt, transforms=transforms[1], load_depth=True,
                                             load_motion=True, load_interest_pt=True,
                                             image_folders=dataset_config.image_folders, input_offsets=input_offsets)
    #
    return train_split, val_split

def tiad_semantic_motion_descriptor_multi_input(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'),
                                                     input_offsets=None):
    """Build train/val TIAD semantic + motion + descriptor datasets (multi-frame input, no depth).

    :param dataset_config: merged with defaults via get_config(); supplies image_folders
    :param root: dataset root directory
    :param split: optional iterable of split names to build; defaults to
        ('train', 'val'). Previously this argument was silently ignored.
    :param transforms: two-element sequence; transforms[0] for train, transforms[1] for val
    :param image_folders: NOTE(review): this parameter is shadowed by
        dataset_config.image_folders below (kept as-is for backward compatibility)
    :param input_offsets: frame offsets forwarded to the loader
    :return: (train_split, val_split) — an entry is None when its split was not requested
    """
    dataset_config = get_config().merge_from(dataset_config)
    gt = "gtFine"
    train_split = val_split = None
    # Honor a caller-supplied split list instead of unconditionally overwriting it.
    split = ['train', 'val'] if split is None else split
    for split_name in split:
        if split_name == 'train':
            train_split = TiadDataLoader(root, split_name, gt, transforms=transforms[0], load_depth=False,
                                               load_motion=True, load_interest_pt=True,
                                               image_folders=dataset_config.image_folders, input_offsets=input_offsets)
        elif split_name == 'val':
            val_split = TiadDataLoader(root, split_name, gt, transforms=transforms[1], load_depth=False,
                                             load_motion=True, load_interest_pt=True,
                                             image_folders=dataset_config.image_folders, input_offsets=input_offsets)
    #
    return train_split, val_split



#Semantic inference
def tiad_segmentation_infer(dataset_config, root, split=None, transforms=None):
    """Create a TIAD segmentation dataset for inference on the 'val' split."""
    # Inference mode: images are searched, ground truth is not required.
    return TiadDataLoader(root, 'val', "gtFine", transforms=transforms,
                          image_folders=('leftImg8bit',), load_segmentation=True,
                          search_images=True, inference=True, additional_info=True)


def tiad_segmentation_measure(dataset_config, root, split=None, transforms=None):
    """Create a TIAD segmentation dataset for accuracy measurement on the 'val' split."""
    # Same as the infer variant but with inference=False so ground truth is loaded.
    return TiadDataLoader(root, 'val', "gtFine", transforms=transforms,
                          image_folders=('leftImg8bit',), load_segmentation=True,
                          search_images=True, inference=False, additional_info=True)


def tiad_depth_infer(dataset_config, root, split=None, transforms=None):
    """Create a TIAD depth dataset for inference on the configured split."""
    cfg = dataset_config
    # Split name comes from the config ('train' or 'val').
    return TiadDataLoader(root, cfg.split_name, "gtFine", transforms=transforms,
                          load_segmentation=False, load_depth=True,
                          search_images=True, inference=True, additional_info=True,
                          max_depth=cfg.max_depth_bfr_scaling, depth_scale=cfg.depth_scale,
                          use_semseg_for_depth=cfg.use_semseg_for_depth,
                          train_depth_log=cfg.train_depth_log)


def tiad_depth_measure(dataset_config, root, split=None, transforms=None):
    """Create a TIAD depth dataset for accuracy measurement on the 'val' split."""
    cfg = dataset_config
    return TiadDataLoader(root, 'val', "gtFine", transforms=transforms,
                          load_segmentation=False, load_depth=True,
                          search_images=True, inference=False, additional_info=True,
                          max_depth=cfg.max_depth_bfr_scaling, depth_scale=cfg.depth_scale,
                          train_depth_log=cfg.train_depth_log)


# motion inference
def tiad_motion_multi_input_infer(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious', 'leftImg8bit'), input_offsets=None):
    """Create a TIAD motion dataset (multi-frame input) for inference on the 'val' split."""
    # Merge the given config over the defaults; image folders come from the merged config.
    cfg = get_config().merge_from(dataset_config)
    return TiadDataLoader(root, 'val', "gtFine", transforms=transforms,
                          load_segmentation=False, load_motion=True,
                          image_folders=cfg.image_folders, search_images=True,
                          inference=True, additional_info=True, input_offsets=input_offsets)


def tiad_motion_multi_input_measure(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious', 'leftImg8bit'),input_offsets=None):
    """Create a TIAD motion dataset (multi-frame input) for measurement on the 'val' split."""
    cfg = get_config().merge_from(dataset_config)
    return TiadDataLoader(root, 'val', "gtFine", transforms=transforms,
                          load_segmentation=False, load_motion=True,
                          image_folders=cfg.image_folders, search_images=True,
                          inference=False, additional_info=True, input_offsets=input_offsets)


# single task semantic motion measure
def tiad_semantic_motion_single_task_multi_input_measure(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious', 'leftImg8bit'), input_offsets=None):
    """Create a single-task semantic+motion dataset (multi-frame input) for measurement."""
    cfg = get_config().merge_from(dataset_config)
    # Split name ('train'/'val') comes from the merged config.
    return TiadDataLoader(root, cfg.split_name, "gtFine", transforms=transforms,
                          load_segmentation=False, load_semantic_motion=True,
                          image_folders=cfg.image_folders, additional_info=True,
                          input_offsets=input_offsets)


# single task semantic motion inference
def tiad_semantic_motion_single_task_multi_input_infer(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious', 'leftImg8bit'), input_offsets=None):
    """Create a single-task semantic+motion dataset (multi-frame input) for inference."""
    cfg = get_config().merge_from(dataset_config)
    return TiadDataLoader(root, cfg.split_name, "gtFine", transforms=transforms,
                          load_segmentation=False, load_semantic_motion=True,
                          image_folders=cfg.image_folders, search_images=True,
                          inference=True, additional_info=True, input_offsets=input_offsets)


# joint motion semantic inference
def tiad_motion_semantic_multi_input_infer(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'), input_offsets=None):
    """Create a joint motion+semantic dataset (multi-frame input) for inference on 'val'."""
    cfg = get_config().merge_from(dataset_config)
    return TiadDataLoader(root, 'val', "gtFine", transforms=transforms, load_motion=True,
                          image_folders=cfg.image_folders, search_images=True,
                          inference=True, additional_info=True, input_offsets=input_offsets)



def tiad_motion_semantic_multi_input_measure(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'), input_offsets=None):
    """Create a joint motion+semantic dataset (multi-frame input) for measurement on 'val'."""
    cfg = get_config().merge_from(dataset_config)
    return TiadDataLoader(root, 'val', "gtFine", transforms=transforms, load_motion=True,
                          image_folders=cfg.image_folders, search_images=True,
                          inference=False, additional_info=True, input_offsets=input_offsets)


def tiad_depth_semantic_motion_multi_input_measure(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'),
                                                     input_offsets=None):
    """Create a joint depth+semantic+motion dataset (multi-frame input) for measurement."""
    cfg = get_config().merge_from(dataset_config)
    # Split name ('train'/'val') comes from the merged config.
    return TiadDataLoader(root, cfg.split_name, "gtFine", transforms=transforms,
                          load_depth=True, search_images=True, inference=False,
                          additional_info=True, load_motion=True,
                          image_folders=cfg.image_folders, input_offsets=input_offsets)


############################################
def tiad_depth_semantic_motion_multi_input_infer(dataset_config, root, split=None, transforms=None, image_folders=('leftImg8bitPrevious','leftImg8bit'),
                                                     input_offsets=None):
    """Create a joint depth + semantic + motion dataset (multi-frame input) for inference.

    Fix: pass dataset_config.image_folders (the merged configuration) to the loader,
    consistent with tiad_depth_semantic_motion_multi_input_measure and every other
    multi_input factory; previously the raw image_folders parameter was used here,
    silently bypassing the merged config.

    :param dataset_config: merged with defaults via get_config(); supplies split_name
        and image_folders
    :param root: dataset root directory
    :param split: unused; the split comes from dataset_config.split_name
    :param transforms: transforms applied to the loaded samples
    :param image_folders: default folder names; superseded by the merged config
    :param input_offsets: frame offsets forwarded to the loader
    :return: the inference dataset
    """
    dataset_config = get_config().merge_from(dataset_config)
    gt = "gtFine"
    split = dataset_config.split_name
    val_split = TiadDataLoader(root, split, gt, transforms=transforms, load_depth=True, search_images=True, inference=True, additional_info=True,
                                               load_motion=True, image_folders=dataset_config.image_folders, input_offsets=input_offsets)
    #
    return val_split


#interest_pt inference
def tiad_interest_pt_infer(dataset_config, root, split=None, transforms=None):
    """Create a TIAD interest-point dataset for inference on the configured split."""
    # Split name ('train'/'val') comes from the config.
    return TiadDataLoader(root, dataset_config.split_name, "gtFine", transforms=transforms,
                          image_folders=('leftImg8bit',), load_segmentation=False,
                          load_interest_pt=True, search_images=True,
                          inference=True, additional_info=True)


def tiad_interest_pt_measure(dataset_config, root, split=None, transforms=None):
    """Create a TIAD interest-point dataset for accuracy measurement.

    The AKAZE descriptor source depends on the configured write_desc_type:
    'PRED' disables precomputed descriptors ('none'); anything else reads
    precomputed binary descriptors ('precomputed_bin'). Other supported values
    of akaze_format are 'precomputed_npy' and 'compute'.
    """
    akaze_fmt = 'none' if dataset_config.write_desc_type == 'PRED' else 'precomputed_bin'
    return TiadDataLoader(root, dataset_config.split_name, "gtFine", transforms=transforms,
                          image_folders=('leftImg8bit',), load_segmentation=False,
                          load_interest_pt=True, search_images=True, inference=False,
                          additional_info=True, akaze_format=akaze_fmt)

